# -*- coding: utf-8 -*-
"""
Created on Thu Jun 18 19:09:10 2020

@author: Yixuan Pan (ypan1003@umd.edu)
"""

import pandas as pd
import os
import gc
import warnings
warnings.filterwarnings('ignore')
import time


# Function to restrict the percentage reduction to 0-100 %
def reasonable_check(series):
    """Clamp every element of *series* into the closed range [0, 100].

    Values below 0 become 0 and values above 100 become 100. Because the
    clamp is written as ``min(100, max(0, value))``, a NaN element also
    becomes 0: ``max(0, nan)`` keeps its first argument, since ``nan > 0``
    is False.

    Parameters
    ----------
    series : pandas.Series
        Numeric values, e.g. percentage reductions.

    Returns
    -------
    pandas.Series
        A new Series with the clamped values; the input is not modified.
    """
    def _clamp(value):
        # Lower bound first, then upper bound; the argument order matters for NaN.
        return min(100, max(0, value))

    return series.apply(_clamp)

# Function to calculate the daily Social Distancing Index (SDI)
def social_distancing_index(current_metrics, benchmark_metrics,
                            beta=(1, 0.25, 0.45, 0.3, 0.2)):
    """Compute the Social Distancing Index (SDI) for one snapshot of metrics.

    Parameters
    ----------
    current_metrics : pandas.DataFrame
        Metrics being scored. Must have either a 'CTFIPS' column (used as
        the key when present) or an 'STFIPS' column, plus the columns
        '% staying home', '#work trips/person', '#non-work trips/person',
        'miles traveled/person' and '#out-of-county trips'.
    benchmark_metrics : pandas.DataFrame
        Baseline metrics keyed by the same FIPS column, with the same four
        trip/mileage columns. Rows are matched to ``current_metrics`` by
        that key, not by position.
    beta : sequence of 5 numbers
        Weights of the index. ``beta[0]`` multiplies X1 (% staying home),
        ``beta[1..3]`` weight the reductions X2..X4, and ``beta[4]`` blends
        X5 against the rest. A tuple default avoids a shared mutable
        default; any indexable sequence still works.

    Returns
    -------
    pandas.DataFrame
        One row per row of ``current_metrics`` with the FIPS key restored
        as a column and an added 'social distancing index' column.

    Raises
    ------
    KeyError
        If neither frame pair has the 'CTFIPS' key and one of them also
        lacks 'STFIPS', or if a metric column is missing.
    """
    # Prefer the county key. If either frame lacks 'CTFIPS', set_index raises
    # KeyError and both frames fall back to the state key.
    try:
        current = current_metrics.set_index('CTFIPS')
        benchmark = benchmark_metrics.set_index('CTFIPS')
    except KeyError:
        current = current_metrics.set_index('STFIPS')
        benchmark = benchmark_metrics.set_index('STFIPS')

    def _reduction(column):
        # Percentage reduction of *column* versus the benchmark, aligned on FIPS.
        return (benchmark[column] - current[column]).div(benchmark[column]) * 100

    X = pd.concat([current['% staying home'],
                   _reduction('#work trips/person'),
                   _reduction('#non-work trips/person'),
                   _reduction('miles traveled/person'),
                   _reduction('#out-of-county trips')], axis=1)
    X.columns = ['X1', 'X2', 'X3', 'X4', 'X5']

    # Restrict the percentage reductions X2..X5 to 0-100 % (X1 is left as-is).
    # This vectorised clamp gives the same result as reasonable_check()
    # element for element: <0 (incl. -inf) -> 0, >100 (incl. +inf) -> 100,
    # and NaN (e.g. 0/0, or a FIPS missing from one frame) -> 0.
    # NOTE(review): NaN -> 0 silently treats missing benchmark data as "no
    # reduction"; confirm that is intended.
    reduction_cols = ['X2', 'X3', 'X4', 'X5']
    X[reduction_cols] = X[reduction_cols].clip(lower=0, upper=100).fillna(0)

    # SDI = (X1*b0 + (1 - X1/100) * (X2*b1 + X3*b2 + X4*b3)) * (1 - b4) + X5*b4
    weighted_reductions = X['X2'] * beta[1] + X['X3'] * beta[2] + X['X4'] * beta[3]
    sd_index = ((X['X1'] * beta[0] + (1 - X['X1'] / 100) * weighted_reductions)
                * (1 - beta[4]) + X['X5'] * beta[4])
    sd_index = sd_index.rename('social distancing index')

    # Left join keeps exactly current_metrics' regions and order. An outer
    # concat would add all-NaN rows (including 'date') for regions that
    # appear only in the benchmark.
    return current.join(sd_index, how='left').reset_index()

# Function to calculate SDI scores for all the days
def main(level):
    """Compute the Social Distancing Index for every date at one level.

    Reads '<level>_mobility_metrics.csv' and '<level>_benchmark_metrics.csv'
    from the current working directory. It scores each date's rows against
    the benchmark with social_distancing_index() and stacks the results.

    Parameters
    ----------
    level : str
        'county' or 'state'.

    Returns
    -------
    pandas.DataFrame
        Per-date results stacked in the order the dates first appear in the
        metrics file, with a fresh RangeIndex. Returns an empty DataFrame
        when the metrics file has no dated rows.

    Raises
    ------
    ValueError
        If ``level`` is not 'county' or 'state'. Previously this only
        printed a message and then failed with UnboundLocalError.
    """
    # FIPS columns are read as strings so they are not parsed as integers
    # (which would drop any leading zeros).
    fips_dtypes = {
        'county': {'STFIPS': str, 'CTFIPS': str},
        'state': {'STFIPS': str},
    }
    if level not in fips_dtypes:
        raise ValueError('Wrong level specified: %r (expected "county" or "state").'
                         % (level,))
    dtype = fips_dtypes[level]
    current_metrics = pd.read_csv('%s_mobility_metrics.csv' % level, dtype=dtype)
    benchmark_metrics = pd.read_csv('%s_benchmark_metrics.csv' % level, dtype=dtype)

    # Score each date separately. groupby(sort=False) yields dates in
    # first-appearance order (like Series.unique()) in a single pass, instead
    # of one full boolean scan per date.
    daily_results = []
    for _, day_metrics in current_metrics.groupby('date', sort=False):
        day_metrics = day_metrics.sort_values('STFIPS').reset_index(drop=True)
        daily_results.append(social_distancing_index(day_metrics, benchmark_metrics))

    # Concatenate once: DataFrame.append was removed in pandas 2.0 and
    # appending inside the loop is quadratic.
    if not daily_results:
        return pd.DataFrame()
    return pd.concat(daily_results, ignore_index=True)

if __name__ == "__main__":
    # Build the SDI tables at county level first, then state level. The
    # results are only bound to these module-level names; nothing is written
    # to disk here.
    county_metrics_sdi, state_metrics_sdi = [
        main(level=geo_level) for geo_level in ('county', 'state')
    ]
